---
title: Using fastai for segmentation
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/examples.fastai.segmentation.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
from pathlib import Path
from drone_detector.processing.tiling import *
import os
from fastai.vision.all import *
from drone_detector.engines.fastai.data import *
{% endraw %} {% raw %}
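# Folder with the image tiles; the matching mask tiles live in a sibling 'mask_tiles' folder
outpath = Path('../data/historic_map/processed/raster_tiles/')

fnames = [outpath/f for f in os.listdir(outpath)]

dls = SegmentationDataLoaders.from_label_func('../data/historic_map/', bs=16,
                                              codes=['Marshes'],
                                              fnames=fnames,
                                              # Look up each tile's mask in 'mask_tiles' instead of 'raster_tiles'
                                              label_func=partial(label_from_different_folder,
                                                                 original_folder='raster_tiles',
                                                                 new_folder='mask_tiles'),
                                              # Default fastai augmentations without rotation or warping,
                                              # then normalization with ImageNet statistics
                                              batch_tfms = [
                                                  *aug_transforms(max_rotate=0., max_warp=0.),
                                                  Normalize.from_stats(*imagenet_stats)
                                              ])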
{% endraw %}

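`label_from_different_folder` is a helper from `drone_detector.engines.fastai.data`; with the arguments above, it locates each tile's mask in `mask_tiles` instead of `raster_tiles`. The same module also contains helpers for other kinds of data, for instance multispectral images or time series of images.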
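A minimal sketch of how such a label function can map an image tile to its mask tile. `mask_path_from_image_path` is a hypothetical stand-in, not part of `drone_detector`, and it assumes the mask has the same file name as the image and sits in a sibling folder; the file name `tile_0.tif` is likewise only illustrative.

```python
from pathlib import Path

def mask_path_from_image_path(fn, original_folder: str = 'raster_tiles',
                              new_folder: str = 'mask_tiles') -> Path:
    """Hypothetical stand-in for label_from_different_folder.

    Swaps the directory component `original_folder` for `new_folder`,
    keeping every other part of the path, including the file name."""
    parts = [new_folder if part == original_folder else part for part in Path(fn).parts]
    return Path(*parts)

mask_fn = mask_path_from_image_path('../data/historic_map/processed/raster_tiles/tile_0.tif')
print(mask_fn.as_posix())  # ../data/historic_map/processed/mask_tiles/tile_0.tif
```

Comparing whole path components, rather than doing a plain string replace, avoids rewriting a file name that happens to contain `raster_tiles`.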

{% raw %}
dls.show_batch(max_n=16)
{% endraw %}

Train a basic U-Net with a pretrained ResNet50 encoder. `to_fp16()` enables mixed-precision training, which uses less GPU memory. The loss function is `FocalLossFlat`; for segmentation we need to specify `axis=1`, the class (channel) axis of the model output. The metrics are Dice and `JaccardCoeff` (intersection over union), both standard segmentation metrics.
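To make the `axis=1` requirement concrete, here is a NumPy sketch of the focal loss, FL = -(1 - p_t)^gamma * log(p_t), for segmentation logits shaped (batch, classes, height, width). The softmax is taken over axis 1, the class axis. `focal_loss_sketch` is a hypothetical illustration; `gamma=2.0` mirrors what we understand fastai's default to be, and fastai's implementation may differ in details such as class weighting and reduction.

```python
import numpy as np

def focal_loss_sketch(logits: np.ndarray, target: np.ndarray, gamma: float = 2.0) -> float:
    """Mean focal loss over all pixels.

    logits: float array of shape (N, C, H, W), raw model outputs
    target: int array of shape (N, H, W), class index per pixel
    """
    # Numerically stable log-softmax over the class axis (axis=1)
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # Pick log p_t, the log-probability of the true class at each pixel
    log_pt = np.take_along_axis(log_probs, target[:, None], axis=1)[:, 0]
    pt = np.exp(log_pt)
    # (1 - p_t)^gamma down-weights pixels that are already classified confidently
    return float(np.mean(-((1.0 - pt) ** gamma) * log_pt))

# Two classes (background, marsh) on a 2x2 tile: uniform logits give p_t = 0.5
logits = np.zeros((1, 2, 2, 2))
target = np.array([[[0, 1], [1, 0]]])
print(focal_loss_sketch(logits, target, gamma=0.0))  # plain cross-entropy, ln 2
print(focal_loss_sketch(logits, target, gamma=2.0))  # 0.25 * ln 2
```

With `gamma=0` the expression reduces to ordinary cross-entropy; larger `gamma` shifts the training signal toward hard pixels, which helps when marsh pixels are rare compared to background.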

{% raw %}
learn = unet_learner(dls, resnet50, pretrained=True, n_in=3, n_out=2,
                     metrics=[Dice(), JaccardCoeff()], loss_func=FocalLossFlat(axis=1)
                    ).to_fp16()
{% endraw %}

Search for a suitable learning rate.

{% raw %}
learn.lr_find()
SuggestedLRs(valley=3.630780702224001e-05)
{% endraw %}
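`lr_find` runs a short learning-rate range test: it trains on a number of mini-batches while raising the learning rate exponentially and records the loss at each step. The sketch below reproduces only an exponential schedule of that kind; `exp_lr_schedule` is a hypothetical helper, the values `start_lr=1e-7`, `end_lr=10` and `num_it=100` are what we believe fastai's defaults to be, and the logic that picks the `valley` suggestion is not reproduced.

```python
def exp_lr_schedule(start_lr: float = 1e-7, end_lr: float = 10.0, num_it: int = 100) -> list:
    """Learning rates for an exponential LR range test.

    Iteration i uses start_lr * (end_lr / start_lr) ** (i / num_it), so the
    rate is multiplied by the same constant factor at every step."""
    return [start_lr * (end_lr / start_lr) ** (i / num_it) for i in range(num_it)]

lrs = exp_lr_schedule()
step_factor = lrs[1] / lrs[0]   # about 1.2: each step raises the LR by roughly 20%
print(f"first: {lrs[0]:.1e}, last: {lrs[-1]:.2e}, factor per step: {step_factor:.3f}")
```

The suggested valley above (about 3.6e-5) lies well inside this range; the training below uses a somewhat larger `base_lr=1e-4`.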

Train the model for 2 epochs with encoder layers frozen and 10 epochs with all layers unfrozen.
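As we read fastai's `fine_tune` source, the two phases use different peak learning rates: the frozen phase trains the head with `base_lr`; then `base_lr` is halved and, after unfreezing, spread across layer groups from `base_lr / lr_mult` for the earliest layers to `base_lr` for the head, with `lr_mult=100` by default. The hypothetical helper below is a sketch of that arithmetic under this reading; check it against the fastai version you use.

```python
def fine_tune_lrs(base_lr: float, lr_mult: float = 100.0) -> dict:
    """Peak learning rates of the two fine_tune phases (our reading of fastai)."""
    frozen = base_lr                        # phase 1: encoder frozen, head trains at base_lr
    unfrozen_max = base_lr / 2              # phase 2: base_lr is halved before unfreezing
    unfrozen_min = unfrozen_max / lr_mult   # earliest encoder layers get the smallest rate
    return {'frozen': frozen, 'unfrozen': (unfrozen_min, unfrozen_max)}

print(fine_tune_lrs(1e-4))
```

With `base_lr=1e-4`, this gives 1e-4 for the two frozen epochs and a range from 5e-7 to 5e-5 for the ten unfrozen epochs.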

{% raw %}
from fastai.callback.progress import ShowGraphCallback
learn.fine_tune(10, freeze_epochs=2, base_lr=1e-4, cbs=ShowGraphCallback)

| epoch | train_loss | valid_loss | dice | jaccard_coeff | time |
|---|---|---|---|---|---|
| 0 | 0.083683 | 0.049321 | 0.087302 | 0.045643 | 00:14 |
| 1 | 0.085760 | 0.070604 | 0.571797 | 0.400361 | 00:11 |

| epoch | train_loss | valid_loss | dice | jaccard_coeff | time |
|---|---|---|---|---|---|
| 0 | 0.047497 | 0.027717 | 0.719987 | 0.562484 | 00:11 |
| 1 | 0.037774 | 0.031188 | 0.763975 | 0.618090 | 00:11 |
| 2 | 0.036798 | 0.020224 | 0.824413 | 0.701278 | 00:11 |
| 3 | 0.032179 | 0.016441 | 0.857252 | 0.750167 | 00:11 |
| 4 | 0.029569 | 0.016763 | 0.864597 | 0.761489 | 00:11 |
| 5 | 0.025802 | 0.014605 | 0.876402 | 0.779996 | 00:11 |
| 6 | 0.023174 | 0.015947 | 0.870836 | 0.771222 | 00:11 |
| 7 | 0.021236 | 0.017080 | 0.862995 | 0.759007 | 00:11 |
| 8 | 0.019667 | 0.013650 | 0.885709 | 0.794864 | 00:11 |
| 9 | 0.018320 | 0.013593 | 0.884863 | 0.793501 | 00:11 |
{% endraw %}
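For binary masks, Dice = 2|P ∩ T| / (|P| + |T|) and Jaccard = |P ∩ T| / |P ∪ T|, where P is the predicted marsh mask and T the ground truth. The two are linked by J = D / (2 - D), which is why the `dice` and `jaccard_coeff` columns rise and fall together: for the final epoch, 0.884863 / (2 - 0.884863) ≈ 0.7935. A NumPy sketch, assuming boolean masks of equal shape; `dice_and_jaccard` is a hypothetical helper, and fastai's `Dice` and `JaccardCoeff` accumulate these counts over the whole validation set rather than averaging per image (the relation still holds for the accumulated counts).

```python
import numpy as np

def dice_and_jaccard(pred: np.ndarray, target: np.ndarray) -> tuple:
    """Dice and Jaccard scores for two boolean masks of equal shape.
    Undefined (division by zero) if both masks are empty."""
    inter = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()   # |P| + |T|
    dice = 2 * inter / total
    jaccard = inter / (total - inter)   # |P ∪ T| = |P| + |T| - |P ∩ T|
    return float(dice), float(jaccard)

pred = np.array([[1, 1, 0], [0, 1, 0]], dtype=bool)
target = np.array([[1, 0, 0], [0, 1, 1]], dtype=bool)
d, j = dice_and_jaccard(pred, target)
print(d, j, d / (2 - d))  # the last two values agree: Jaccard equals D / (2 - D)
```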

Return to full precision.

{% raw %}
learn.to_fp32()
{% endraw %}

Check results.

{% raw %}
learn.show_results(max_n=8)
{% endraw %} {% raw %}
# get_preds returns a (predictions, targets) tuple for the validation set
preds, targs = learn.get_preds(with_input=False, with_decoded=False)
{% endraw %}
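`get_preds` applies the loss function's activation to the raw outputs; for `FocalLossFlat(axis=1)` we expect this to be a softmax over the class axis, giving per-pixel probabilities shaped (N, 2, H, W). A predicted mask is then the argmax over axis 1. A NumPy sketch with a small toy array standing in for the predictions returned by `get_preds`:

```python
import numpy as np

# Toy probabilities: 1 image, 2 classes (background, marsh), 2x2 pixels
probs = np.array([[[[0.9, 0.2],
                    [0.4, 0.7]],
                   [[0.1, 0.8],
                    [0.6, 0.3]]]])
# Class with the highest probability at each pixel, shape (N, H, W)
mask = probs.argmax(axis=1)
print(mask.tolist())           # [[[0, 1], [1, 0]]]
marsh_fraction = (mask == 1).mean()
print(marsh_fraction)          # share of pixels predicted as marsh
```

With only two classes, taking the argmax is the same as thresholding the marsh probability at 0.5, apart from exact ties.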

Export the model for later use.

{% raw %}
learn.path = Path('../data/historic_map/models')
learn.export('resnet50_focalloss_swamps.pkl')
{% endraw %}

Some helper functions for inference, including one that removes all resizing transforms from a `DataLoader` so that predictions keep each image's original size.

{% raw %}
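def label_func(fn):
    "Map an image tile path to the corresponding mask tile path"
    return str(fn).replace('raster_tiles', 'mask_tiles')

@patch
def remove(self:Pipeline, t):
    "Remove every transform in the pipeline that has the same class as `t`"
    # Rebuild the list rather than popping items while iterating over it
    self.fs = L([o for o in self.fs if not isinstance(o, t.__class__)])

@patch
def set_base_transforms(self:DataLoader):
    "Remove all resizing transforms (those with a `size` attribute) from `after_item` and `after_batch`"
    for attr in ['after_item', 'after_batch']:
        tfms = getattr(self, attr)
        # Collect the matches first so the pipeline is not modified while iterating over it
        for o in [o for o in tfms.fs if hasattr(o, 'size')]:
            tfms.remove(o)
        setattr(self, attr, tfms)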
{% endraw %}

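Load the exported learner and remove all resizing transforms from its validation `DataLoader`. If you run out of memory, restart the kernel, then re-run the imports and the helper functions above before continuing.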

{% raw %}
testlearn = load_learner('../data/historic_map/models/resnet50_focalloss_swamps.pkl', cpu=False)
testlearn.dls.valid.set_base_transforms()
{% endraw %}

The model is tested on four map patches from different areas and of different sizes; three of them are shown below. Two of the patches are from 1965 and two from 1984, and their sizes range from 600x600 to 1500x1500 pixels.

{% raw %}
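import PIL.Image
import numpy as np

def unet_predict(fn):
    "Predict marshes in the image at `fn` and return the image with all non-marsh pixels set to black"
    # Force 3 channels so palette or RGBA files are handled the same way
    image = np.array(PIL.Image.open(fn).convert('RGB'))
    # The first element returned by predict is the decoded mask, an (H, W) tensor
    mask = testlearn.predict(PILImage.create(image))[0].numpy()
    img = image.copy()
    img[mask == 0] = 0   # zero all three channels wherever the predicted class is background
    return PIL.Image.fromarray(img.astype(np.uint8))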
{% endraw %} {% raw %}
# Sort so that the indices used below always refer to the same files
test_images = sorted(f'../data/historic_map/test_patches/{f}' for f in os.listdir('../data/historic_map/test_patches/'))
{% endraw %}
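`unet_predict` blacks out every non-marsh pixel using NumPy boolean indexing: a 2-D mask of shape (H, W) indexes the first two axes of an (H, W, 3) image, so a single assignment clears all three channels. This relies on the mask having the same height and width as the image, which is why the resizing transforms were removed. A toy example:

```python
import numpy as np

# 2x2 RGB image where every pixel is (10, 20, 30)
img = np.full((2, 2, 3), [10, 20, 30], dtype=np.uint8)
# Predicted mask: 1 = marsh, 0 = background
mask = np.array([[1, 0],
                 [0, 1]])
out = img.copy()        # copy so the original image is left untouched
out[mask == 0] = 0      # zeroes all channels of the background pixels
print(out[0, 1], out[0, 0])  # [0 0 0] [10 20 30]
```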

First result.

{% raw %}
patch_pred = unet_predict(test_images[0])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[0]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[0].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}

Second result.

{% raw %}
patch_pred = unet_predict(test_images[1])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[1]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[1].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}

Third result.

{% raw %}
patch_pred = unet_predict(test_images[3])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[3]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[3].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}